Sensitivity to Smoothing Bandwidth

Set Parameters

# State identifier; used below to build data paths and to filter survey data.
state_name <- "ma"

# When TRUE, only the first few county rows are processed (quick test runs).
testing <- FALSE

set.seed(123)

# Hyperparameters of the priors; the list is passed to get_melded() through
# do.call(), so element names must match that function's argument names.
# Bounds set to NA leave the corresponding distribution untruncated; the
# alternative bounds tried previously are kept as comments next to each entry.
prior_params <- list(
  alpha_mean        = 0.95,
  alpha_sd          = 0.08,
  alpha_bounds      = NA,      # alternative: c(.8, 1)
  beta_mean         = 0.15,
  beta_sd           = 0.09,
  beta_bounds       = NA,      # alternative: c(0.002, 0.4)
  s_untested_mean   = 0.03,
  s_untested_sd     = 0.0225,
  s_untested_bounds = NA,      # alternative: c(0.0018, Inf)
  p_s0_pos_mean     = 0.4,
  p_s0_pos_sd       = 0.1225,
  p_s0_pos_bounds   = NA,      # alternative: c(.25, .7)
  nsamp             = 1e5
)

# Number of replicates requested from generate_corrected_sample().
corrected_sample_reps <- 1e3

# Smoothing span for beta (relevant for versions 2 and 3).
beta_smoothing_span <- 0.15

# Smoothing span for P(S_1|untested) (relevant for versions 3 and 4).
s_untested_smoothing_span <- 0.2
#' Sample the priors and resample them against the prior on phi
#'
#' Draws `nsamp` values of alpha, beta and P_S_untested from their priors,
#' maps every draw through `est_P_A_testpos()` to obtain the induced value
#' phi = M(theta), and then resamples the draws (with replacement) with
#' probability proportional to
#' `1 / sqrt(induced density of phi / prior density of phi)`.
#'
#' @param alpha_mean,alpha_sd,alpha_bounds mean, sd and bounds handed to
#'   `sample_gamma_density()` for alpha.
#' @param beta_mean,beta_sd,beta_bounds mean, sd and bounds handed to
#'   `sample_beta_density()` for beta.
#' @param s_untested_mean,s_untested_sd,s_untested_bounds mean, sd and bounds
#'   handed to `sample_beta_density()` for P_S_untested.
#' @param p_s0_pos_mean,p_s0_pos_sd,p_s0_pos_bounds mean, sd and bounds handed
#'   to `beta_density()` when evaluating the prior density of phi.
#' @param nsamp number of prior draws; also the number of resampled rows and
#'   the number of grid points used by `density()`.
#' @return A list with two tibbles: `post_melding` (resampled alpha, beta,
#'   P_S_untested plus `P_A_testpos`, the resampled phi) and `pre_melding`
#'   (the original draws including `phi_induced`).
get_melded <- function(alpha_mean = 0.9,
                       alpha_sd = 0.04,
                       alpha_bounds = NA,
                       beta_mean = .15,
                       beta_sd =.09,
                       beta_bounds = NA,
                       s_untested_mean = .025,
                       s_untested_sd = .0225,
                       s_untested_bounds = NA,
                       p_s0_pos_mean = .4,
                       p_s0_pos_sd = .1225,
                       p_s0_pos_bounds = NA,
                       nsamp = 1e3) {

  # theta holds the draws of alpha, beta, P_S_untested and the induced
  # M(theta) = phi_induced.
  # NOTE(review): alpha is drawn with sample_gamma_density() while the other
  # parameters use sample_beta_density() -- confirm this is intended.
  theta <- tibble(
    alpha = sample_gamma_density(nsamp,
                                 mean = alpha_mean,
                                 sd = alpha_sd,
                                 bounds = alpha_bounds),
    beta = sample_beta_density(nsamp,
                               mean = beta_mean,
                               sd = beta_sd,
                               bounds = beta_bounds),
    P_S_untested = sample_beta_density(nsamp,
                                       mean = s_untested_mean,
                                       sd = s_untested_sd,
                                       bounds = s_untested_bounds)) %>%
    mutate(phi_induced = est_P_A_testpos(P_S_untested = P_S_untested,
                                         alpha = alpha,
                                         beta = beta))

  phi <- theta$phi_induced

  # approximate the induced distribution of phi with a kernel density estimate
  phi_induced_density <- density(x = phi,
                                 n = nsamp,
                                 adjust = 2,
                                 kernel = "gaussian")

  # evaluate that estimate at every sampled phi by looking up the grid
  # interval the value falls into
  grid_index <- findInterval(phi, phi_induced_density$x)
  phi_sampled_density <- phi_induced_density$y[grid_index]

  # density of every sampled phi under the prior placed directly on phi
  phi_prior_density <- beta_density(phi,
                                    mean = p_s0_pos_mean,
                                    sd = p_s0_pos_sd,
                                    bounds = p_s0_pos_bounds)

  # square root of (induced density / prior density); its reciprocal is the
  # resampling probability
  induced_over_prior <- (phi_sampled_density / phi_prior_density)^(.5)

  post_samp_ind <- sample.int(n = nsamp,
                              size = nsamp,
                              prob = 1 / induced_over_prior,
                              replace = TRUE)

  post_melding <- bind_cols(theta[post_samp_ind, ],
                            P_A_testpos = phi[post_samp_ind]) %>%
    select(-phi_induced)

  list(post_melding = post_melding, pre_melding = theta)
}


#' Reformat melding output for plot generation
#'
#' Stacks the pre-melding draws, the induced phi and the post-melding draws
#' into one long tibble with TeX-style parameter labels.
#'
#' @param melded_df wide tibble of post-melding draws (e.g. the
#'   `post_melding` element returned by `get_melded()`).
#' @param theta_df wide tibble of pre-melding draws containing a
#'   `phi_induced` column (e.g. the `pre_melding` element of `get_melded()`).
#' @param nsamp number of draws taken from the prior on P_A_testpos; because
#'   the draws are added as a column with `mutate()`, this has to be
#'   compatible with the number of rows of `theta_df`.
#' @param p_s0_pos_mean,p_s0_pos_sd,p_s0_pos_bounds mean, sd and bounds handed
#'   to `sample_beta_density()` for the prior on P_A_testpos.
#' @return A long tibble with columns `name` (factor of TeX labels), `value`
#'   and `type` ("Before Melding", "Induced" or "After Melding").
reformat_melded <- function(melded_df,
                            theta_df,
                            nsamp,
                            p_s0_pos_mean,
                            p_s0_pos_sd,
                            p_s0_pos_bounds) {

  # facet labels, in the order the panels should appear
  label_levels <- c("$\\alpha$",
                    "$\\beta$",
                    "$P(S_1|untested)$",
                    "$P(S_0|test+,untested)$")

  melded_df_long <- melded_df %>%
    pivot_longer(cols = everything()) %>%
    mutate(type = "After Melding")

  # The pipeline is the last expression so that the result is returned
  # visibly (previously it was assigned to an unused local, which made the
  # return value invisible).
  theta_df %>%
    # draws from the prior placed directly on P_A_testpos
    mutate(P_A_testpos = sample_beta_density(nsamp,
                                             mean = p_s0_pos_mean,
                                             sd = p_s0_pos_sd,
                                             bounds = p_s0_pos_bounds)) %>%
    pivot_longer(cols = everything()) %>%
    # phi_induced is shown in the same panel as P_A_testpos, distinguished
    # only by `type`
    mutate(type = ifelse(name == "phi_induced",
                         "Induced",
                         "Before Melding")) %>%
    mutate(name = ifelse(name == "phi_induced",
                         "P_A_testpos",
                         name)) %>%
    bind_rows(melded_df_long) %>%
    mutate(name = case_when(
      name == "alpha" ~ "$\\alpha$",
      name == "beta" ~ "$\\beta$",
      name == "P_A_testpos" ~ "$P(S_0|test+,untested)$",
      name == "P_S_untested" ~ "$P(S_1|untested)$")) %>%
    mutate(name = factor(name, levels = label_levels))
}



#' Plot the densities before and after melding
#'
#' Builds two stacked panels: the top one shows every parameter except
#' P(S_0|test+,untested) (legend hidden), the bottom one shows
#' P(S_0|test+,untested) on a 0-1 x axis together with the legend.
#'
#' This function used to be defined twice in a row with identical bodies (the
#' second definition silently replaced the first); only one definition is
#' kept.
#'
#' @param melded long tibble with columns `name`, `value` and `type`, where
#'   `name` holds TeX-style labels (as produced by `reformat_melded()`).
#' @param custom_title title of the top panel; rendered through `TeX()`.
#' @param nsamp number of samples; only used in the subtitle text.
#' @return A patchwork object with the two panels stacked.
plot_melded <- function(melded, custom_title = "", nsamp) {

  phi_label <- "$P(S_0|test+,untested)$"
  fill_colours <- c("#5670BF", "#418F6A", "#B28542")

  # Layers shared by both panels; they differ only in whether the density
  # layer contributes to the legend and in the legend text size.
  density_panel <- function(panel_df, show_legend, legend_text_size) {
    ggplot(panel_df, aes(x = value, fill = type)) +
      geom_density(alpha = .5, show.legend = show_legend) +
      facet_wrap(~name,
                 labeller = as_labeller(TeX, default = label_parsed),
                 ncol = 3,
                 scales = "fixed") +
      theme_bw() +
      theme(axis.title = element_text(size = 18),
            axis.text.x = element_text(size = 10),
            plot.title = element_text(size = 18,
                                      margin = margin(0, 0, .5, 0, "cm")),
            strip.text = element_text(size = 16),
            legend.text = element_text(size = legend_text_size))
  }

  # top panel: all parameters except phi, carries the title and subtitle
  p1 <- melded %>%
    filter(name != phi_label) %>%
    density_panel(show_legend = FALSE, legend_text_size = 16) +
    labs(title = TeX(custom_title, bold = TRUE),
         subtitle = paste0("Number of Samples: ", nsamp),
         fill = "",
         y = "Density") +
    scale_fill_manual(values = fill_colours) +
    guides(fill = guide_legend(keyheight = 2, keywidth = 2))

  # bottom panel: phi only, x axis fixed to the unit interval
  # (show_legend = NA is the geom_density() default)
  p2 <- melded %>%
    filter(name == phi_label) %>%
    density_panel(show_legend = NA, legend_text_size = 18) +
    labs(fill = "",
         y = "Density") +
    scale_fill_manual(values = fill_colours) +
    guides(fill = guide_legend(keyheight = 2, keywidth = 2)) +
    xlim(0, 1)

  # NOTE(review): the layout has two rows, so `heights` may have been meant
  # instead of `widths` -- left unchanged, confirm.
  p1 / p2 + plot_layout(nrow = 2, widths = c(4, 1))
}
library(tidyverse)
library(latex2exp)
library(patchwork)

This section implements version 2 with several smoothing bandwidth choices in order to assess their impact on the melded distributions and the final estimates.

# County-level biweekly data for the selected state.
state_data_path <- paste0(
  here::here(),
  "/data/county_level/",
  state_name,
  "/",
  state_name, "_county_biweekly.RDS"
)
covid_county <- readRDS(state_data_path)

# Helper functions used throughout the rest of this file.
source(paste0(here::here(), "/analysis/base_functions/base_functions.R"))

priors_version <- "v2"

# Survey signals for all states, restricted to the selected state.
fb_symptoms <- filter(
  readRDS(paste0(here::here(),
                 "/data/state_level/screeningpos_all_states.RDS")),
  state == state_name
)



# Draws from the prior on beta, labelled for the comparison plot below.
beta_prior <- tibble(
  value = sample_beta_density(1e5,
                              mean = prior_params$beta_mean,
                              sd = prior_params$beta_sd,
                              bounds = prior_params$beta_bounds),
  type = "Prior on Beta"
)

# ----------------------------------------------------------------------
# Compare the distribution of the survey-based beta estimate (pooled over
# all dates) with the prior on beta.
# ----------------------------------------------------------------------
fb_symptoms %>%
  select(signal, date, value, stderr) %>%
  pivot_wider(names_from = signal,
              values_from = c(value, stderr)) %>%
  # beta estimate = screening test positivity / test positivity
  mutate(
    value = value_smoothed_wscreening_tested_positive_14d /
      value_smoothed_wtested_positive_14d,
    type = "Estimate for Beta\n(Screening Test Positivity/Test Positivity)"
  ) %>%
  select(value, type) %>%
  bind_rows(beta_prior) %>%
  ggplot(aes(x = value, fill = type)) +
  geom_density(alpha = .6) +
  theme_bw() +
  labs(
    title = "Comparing Prior for Beta to Survey Estimates\n from COVID-19 Trends and Impact Survey Data\n Across All Dates",
    subtitle = paste0("State: ", toupper(state_name)),
    fill = ""
  ) +
  theme_c()

# Lookup table read from date_to_biweek.RDS (joined on below to attach
# biweek / date information).
dates <- readRDS(
  paste0(here::here(), "/data/date_to_biweek.RDS")
)

#' LOESS-smooth the survey-based estimate of beta for one span
#'
#' Computes the daily beta estimate (screening test positivity divided by
#' test positivity), fills missing estimates from a rolling mean, fits
#' `loess(beta_estimate ~ index)` with the requested span and returns the raw
#' and smoothed series.
#'
#' @param input_smoothing_span value passed to `loess()` as `span`.
#' @param symptoms_df long survey data with columns `signal`, `date`, `value`
#'   and `stderr`; defaults to the global `fb_symptoms`, which is what the
#'   function always read before this argument existed.
#' @return A tibble with columns `beta_estimate`, `beta_estimate_smoothed`,
#'   `date` and `span`.
get_smoothed <- function(input_smoothing_span, symptoms_df = fb_symptoms) {

  beta_est <- symptoms_df %>%
    select(signal, date, value, stderr) %>%
    pivot_wider(names_from = signal,
                values_from = c(value, stderr)) %>%
    mutate(beta_estimate = value_smoothed_wscreening_tested_positive_14d /
             value_smoothed_wtested_positive_14d) %>%
    arrange(date) %>%
    # seq_len() instead of 1:nrow(.): the latter yields c(1, 0) for an empty
    # table
    mutate(index = seq_len(nrow(.))) %>%
    ungroup() %>%
    # fill last weeks (missing from survey data) with the rolling mean of
    # 3 observations carried forward from the last available value
    mutate(rolled_mean = RcppRoll::roll_mean(beta_estimate,
                                             n = 3,
                                             na.rm = FALSE,
                                             fill = NA)) %>%
    fill(rolled_mean, .direction = "down") %>%
    mutate(beta_estimate = ifelse(is.na(beta_estimate),
                                  rolled_mean,
                                  beta_estimate))

  smoothed_beta <- loess(beta_estimate ~ index,
                         data = beta_est,
                         span = input_smoothing_span)

  beta_est %>%
    mutate(beta_estimate_smoothed = predict(smoothed_beta)) %>%
    select(beta_estimate, beta_estimate_smoothed, date) %>%
    # record the span so results for several spans can be stacked
    mutate(span = input_smoothing_span)
}


# Smooth the beta estimates once per candidate span (.1 to .5 in steps of
# .05) and stack the per-span results row-wise.
beta_est <- map_df(
  seq(.1, .5, by = .05),
  function(span_value) get_smoothed(input_smoothing_span = span_value)
)

# Facet-strip labeller: prefixes a span value with "span: ".
cust_lab <- function(x) {
  paste0("span: ", x)
}

# ----------------------------------------------------------------------
# Survey beta estimates across time, one facet per span: raw estimates as
# black line + points, smoothed estimates as the dark red line.
# ----------------------------------------------------------------------
ggplot(beta_est, aes(x = date, y = beta_estimate)) +
  geom_line() +
  geom_point(alpha = .5) +
  geom_line(aes(y = beta_estimate_smoothed),
            color = "darkred",
            size = 1.2) +
  facet_wrap(~span, labeller = as_labeller(cust_lab)) +
  theme_c() +
  labs(
    title = "Estimates of Beta Across Time",
    subtitle = paste0("State: ", toupper(state_name)),
    y = "Estimate of Beta"
  ) +
  scale_x_date(date_breaks = "2 months", date_labels = "%b %d")

#' Melded priors per biweek, with the beta prior centred at the survey value
#'
#' For every biweek that has a smoothed survey estimate of beta, runs
#' `get_melded()` with `prior_params`, except that `beta_mean` is replaced by
#' that biweek's smoothed estimate, and stacks the post-melding draws.
#'
#' Reads the globals `dates`, `covid_county` and `prior_params`; none of them
#' is modified.
#'
#' @param input_beta_est tibble with columns `date` and
#'   `beta_estimate_smoothed` (one smoothing span, as produced by
#'   `get_smoothed()`).
#' @return A tibble of post-melding draws with an added `biweek` column. The
#'   value is returned visibly (the function previously ended in an
#'   assignment, which made its result invisible).
get_constrained_priors_by_span <- function(input_beta_est) {

  symptoms <- input_beta_est %>%
    select(date, beta_estimate_smoothed) %>%
    left_join(dates) %>%
    group_by(biweek) %>%
    # get last date since observation for date corresponds to value for
    # previous 2 weeks
    slice_max(n = 1, order_by = date) %>%
    ungroup() %>%
    select(-date)

  # distinct local name so the global covid_county is not shadowed
  county_with_beta <- covid_county %>%
    left_join(symptoms) %>%
    # only have CTIS data starting at week 6
    # filter out the beginning dates where beta_estimate_smoothed is NA
    filter(!is.na(beta_estimate_smoothed))

  # beta estimates are at the state level, so every county repeats the same
  # (biweek, beta) pair; keep each pair once
  beta_by_biweek <- county_with_beta %>%
    select(biweek, beta_estimate_smoothed) %>%
    arrange(biweek) %>%
    distinct()

  pmap_df(beta_by_biweek, function(...) {
    biweek_row <- list(...)
    # work on a copy of the prior parameters with beta centred at this
    # biweek's smoothed survey estimate
    biweek_params <- prior_params
    biweek_params$beta_mean <- biweek_row$beta_estimate_smoothed
    res <- do.call(get_melded, biweek_params)
    res$post_melding %>%
      mutate(biweek = biweek_row$biweek)
  })
}


# Run the per-biweek melding separately for every smoothing span and stack
# the results, tagging each row with the span it was produced with.
constrained_priors_all <- beta_est %>%
  group_by(span) %>%
  group_split() %>%
  map_df(function(span_df) {
    span <- unique(span_df$span)
    melded_priors <- get_constrained_priors_by_span(input_beta_est = span_df)
    mutate(melded_priors, span = span)
  })

Compare Melded Priors to Original Across Time

# Previously saved melded draws (constrained_v1), used as the reference.
melded_original <- readRDS(
  paste0(here::here(), "/analysis/results/melded/constrained_v1.RDS")
)

# plot melded distributions and compare to original:
# repeat the reference beta draws once for each biweek 6..30 so they can be
# shown next to every biweek-specific distribution
original <- map_df(6:30, function(biweek_id) {
  mutate(select(melded_original, beta),
         source = "original",
         biweek = biweek_id)
})
 

# Ridge plot of the post-melding beta draws per biweek (prior centred at the
# survey value) next to the draws stored in the global `original`.
#   melded: post-melding draws with columns beta and biweek
#   span:   smoothing span, only shown in the subtitle
compare <- function(melded, span) {

  plot_title <- paste0("Comparing Post-melding Distribution Centered\n",
         "at Survey Estimate of Beta\nto Original Prior")

  stacked <- melded %>%
    mutate(source = "centered at\nsurvey value") %>%
    select(beta, biweek, source) %>%
    bind_rows(original) %>%
    mutate(biweek = as.factor(biweek))

  # biweeks are ordered in descending order along the y axis
  ggplot(stacked,
         aes(x = beta,
             y = fct_reorder(biweek,
                             as.numeric(biweek),
                             .desc = TRUE),
             fill = source)) +
    ggridges::geom_density_ridges(alpha = .6) +
    labs(y = "Biweek",
         title = plot_title,
         subtitle = paste0("Span = ", span),
         fill = "") +
    theme_c(legend.title = element_text(face = "bold", size = 16),
            legend.position = "top")
}
# Print one comparison ridge plot per smoothing span.
constrained_priors_all %>%
  group_by(span) %>%
  group_split() %>%
  walk(function(span_df) {
    input_span <- unique(span_df$span)
    print(compare(span_df, span = input_span))
  })

# Keep biweek 6 onwards only.
covid_county <- filter(covid_county, biweek >= 6)

# only use a few rows if testing
if (testing) {
  covid_county <- covid_county[1:4, ]
}

# Compute summarised corrected counts for every row of the global
# covid_county, using the melded priors of one smoothing span.
#   constrained_priors: melded draws with columns biweek and span
# Returns the stacked per-row summaries with a span column added.
get_corrected_counts <- function(constrained_priors) {

  span <- unique(constrained_priors$span)

  # handles one row of covid_county (passed as named arguments by pmap_df)
  correct_one_row <- function(...) {
    input_df <- list(...)
    # progress output: biweek and county being processed
    message(paste0(input_df$biweek, ", " , input_df$fips))

    # priors are biweek-specific: keep only the draws for this row's biweek
    biweek_priors <-
      constrained_priors[constrained_priors$biweek == input_df$biweek,]

    county_priors <- process_priors_per_county(
      priors = biweek_priors,
      county_df = input_df,
      nsamp = prior_params$nsamp)

    corrected <- generate_corrected_sample(county_priors,
                                           num_reps = corrected_sample_reps)

    summarize_corrected_sample(corrected)
  }

  pmap_df(covid_county, correct_one_row) %>%
    mutate(span = span)
}


# Corrected counts for every smoothing span, stacked into one table.
corrected_counts <- constrained_priors_all %>%
  group_by(span) %>%
  group_split() %>%
  map_df(function(span_priors) {
    get_corrected_counts(constrained_priors = span_priors)
  })
# Ribbon of the corrected-count interval over time for one county, one facet
# per smoothing span.
#   fips_code: FIPS code of the county to plot
plot_corrected_counts_by_span <- function(fips_code) {
  corrected_counts %>%
    filter(fips == fips_code) %>%
    left_join(dates) %>%
    ggplot(aes(x = date)) +
    geom_ribbon(aes(ymin = exp_cases_lb, ymax = exp_cases_ub)) +
    facet_wrap(~span, labeller = as_labeller(cust_lab)) +
    labs(title = paste0("Corrected Counts by Span for FIPS: ", fips_code)) +
    theme_c()
}

# Show the plot for three randomly drawn counties (this replaces three
# copy-pasted blocks).
fips_values <- unique(corrected_counts$fips)

for (plot_index in seq_len(3)) {
  # Index with sample.int() rather than calling sample(fips_values, 1):
  # sample() on a single numeric value x draws from 1:x, which would pick a
  # non-existent FIPS code when only one county is present. With more than
  # one county this yields exactly the same draw as sample() did.
  fips_sample <- fips_values[sample.int(length(fips_values), 1)]
  # explicit print(): plots are not auto-printed inside a loop
  print(plot_corrected_counts_by_span(fips_sample))
}